from mynumpy import *

class estimator_antithetic:
    """Antithetic-variates wrapper around a base estimator.

    Every quantity is evaluated with ``base_est`` at both ``omega`` and its
    reflection ``1 - omega``, and the two evaluations are combined as an
    equal-weight mixture.

    ``base_est`` must provide the attributes ``label``, ``omega_dim`` and
    ``w_dim`` and the methods ``sample_omega(num)``, ``logR(omega, w)``,
    ``sample_zs(omega, w)`` and ``sample_z(omega, w)``.
    """

    def __init__(self, base_est):
        self.base_est = base_est
        self.label = 'ant-' + self.base_est.label
        self.omega_dim = base_est.omega_dim
        self.w_dim = base_est.w_dim

    @staticmethod
    def _antithetic_pair(omega):
        """Return ``(omega, 1 - omega)``, the antithetic pair of draws."""
        # NOTE(review): the reflection 1 - omega preserves the distribution
        # only if omega is U[0, 1]-distributed (presumably what
        # base_est.sample_omega produces) — confirm. An earlier variant used
        # ((1 + omega) / 2, (1 - omega) / 2) instead.
        return omega, 1 - omega

    def sample_omega(self, num):
        """Draw ``num`` omegas; sampling is delegated to the base estimator."""
        return self.base_est.sample_omega(num)

    def logR(self, omega, w):
        """Return log(0.5 * (R(omega) + R(1 - omega))), with R = exp(base logR).

        Computed in log space with ``np.logaddexp`` so that large or very
        negative base log-ratios do not overflow/underflow.
        """
        omega1, omega2 = self._antithetic_pair(omega)
        log_half = log(.5)
        logR1 = self.base_est.logR(omega1, w)
        logR2 = self.base_est.logR(omega2, w)
        return np.logaddexp(log_half + logR1, log_half + logR2)

    def sample_zs(self, omega, w):
        """Sample z at both antithetic omegas and stack the results.

        Rows from ``base_est.sample_zs(omega, w)`` come first, followed by
        those from ``base_est.sample_zs(1 - omega, w)`` (``np.vstack``).
        """
        omega1, omega2 = self._antithetic_pair(omega)
        z1 = self.base_est.sample_zs(omega1, w)
        z2 = self.base_est.sample_zs(omega2, w)
        return np.vstack([z1, z2])

    def sample_z(self, omega, w):
        """Sample z from the antithetic mixture.

        For each draw, the base sample at ``omega`` is kept with probability
        R1 / (R1 + R2); otherwise the base sample at ``1 - omega`` is used.
        Here R1 = exp(base logR(omega)) and R2 = exp(base logR(1 - omega)).
        """
        omega1, omega2 = self._antithetic_pair(omega)
        logR1 = self.base_est.logR(omega1, w)
        logR2 = self.base_est.logR(omega2, w)

        z1 = self.base_est.sample_z(omega1, w)
        z2 = self.base_est.sample_z(omega2, w)

        # Probability of keeping z1, R1 / (R1 + R2), computed in log space.
        p1 = exp(logR1 - np.logaddexp(logR1, logR2))
        # NOTE(review): assumes the last axis of omega indexes the draws and
        # that p1, z1 and z2 broadcast against a vector of that length — confirm.
        r = rand(omega.shape[-1])

        # np.where rather than mask arithmetic: mask*z1 + ~mask*z2 turns an
        # inf/NaN in the *unselected* sample into NaN (0 * inf == NaN).
        return np.where(r < p1, z1, z2)
